# from prompts_report import get_report_evaluation_instruction
import json
import os
import random
from pathlib import Path
import time
import asyncio
from tqdm import tqdm
from openai import OpenAI
import argparse


def parse_args(argv=None):
    """Parse command-line options for the report evaluation script.

    Args:
        argv: Optional list of argument strings to parse. Defaults to None,
            in which case argparse falls back to sys.argv[1:], preserving the
            original call-site behavior (`parse_args()`). Passing an explicit
            list makes the function unit-testable.

    Returns:
        argparse.Namespace with attributes: api_base_url, api_key,
        model_to_test_dir, models.
    """
    parser = argparse.ArgumentParser(description='Evaluate research reports using different models')
    parser.add_argument('--api-base-url',
                       default="https://openrouter.ai/api/v1",
                       help='API base URL')
    parser.add_argument('--api-key',
                       required=True,
                       help='API key for authentication')
    parser.add_argument('--model-to-test-dir',
                       required=True,
                       help='Directory containing the model outputs to test')
    parser.add_argument('--models',
                       nargs='+',
                       default=["deepseek/deepseek-r1"],
                       help='List of models to evaluate')
    return parser.parse_args(argv)

# Move args parsing to before any usage
args = parse_args()

# Replace hardcoded values with args
API_BASE_URL = args.api_base_url
API_KEY = args.api_key

# Module-level OpenAI-compatible client (OpenRouter by default); used by
# create_chat_completion() below.
client = OpenAI(
    api_key=API_KEY,
    base_url=API_BASE_URL,
)

# NOTE(review): machine-specific absolute path — consider promoting this to a
# CLI argument like the others so the script runs outside this environment.
workspace_dir = "/share/project/xiaoxi/WebThinker-main"

# JSON file with the test questions (list of objects with a "Question" key).
test_path = f"{workspace_dir}/data/Glaive/test.json"

# Baseline system output directories; each is expected to contain files named
# article_<n>.md (see the main loop below). Timestamped suffixes are frozen
# snapshots of previous runs.
naive_rag_dir = f"{workspace_dir}/outputs/Glaive.Qwen2.5-72B-Instruct.naive_rag/markdown.test.3.28,20:55.94"
rag_r1_dir = f"{workspace_dir}/outputs/Glaive.deepseek-reasoner.naive_rag/markdown.test.4.20,18:27.65"
gemini_dir = f"{workspace_dir}/outputs/glaive.Gemini.DeepResearch"
grok3_dir = f"{workspace_dir}/outputs/glaive.Grok3.DeeperSearch"

# model_to_test_dir = "./outputs/glaive.qwq.webthinker/markdown.test.3.27,21:47.41"
model_to_test_dir = args.model_to_test_dir

# Judge models (replaces the previously hardcoded MODELS_TO_EVALUATE)
MODELS_TO_EVALUATE = args.models

def get_report_evaluation_instruction(question, system_a, system_b, system_c, system_d, system_e):
    """Build the judging prompt asking an LLM to score five anonymized reports.

    The research question and the grading rubric are stated both before and
    after the article bodies, so the instructions stay in context even when the
    articles make the prompt very long. The prompt ends with the exact JSON
    schema the judge must fill in (parsed later by extract_scores()).

    Args:
        question: The research question all five reports address.
        system_a: Report text assigned to the anonymized label "System A".
        system_b: Report text for "System B".
        system_c: Report text for "System C".
        system_d: Report text for "System D".
        system_e: Report text for "System E".

    Returns:
        The complete prompt string.
    """
    return f"""Research Question: {question}

Please objectively evaluate the quality of research articles generated by systems A, B, C, D and E for this question, and provide scores out of 10 for the following criteria:
(1) Overall Comprehensiveness: The report should cover content as comprehensively as possible
(2) Thoroughness of Discussion: Each section should be discussed thoroughly, not just superficially
(3) Factuality: There should be minimal factual errors
(4) Coherence: The discussion should stay focused and relevant to the topic

Notes:
- A satisfactory performance deserves around 5 points, with higher scores for excellence and lower scores for deficiencies
- You should not easily assign scores higher than 8 or lower than 3 unless you provide substantial reasoning.
- You do not need to consider citations in the articles


----------------------------------------------------------
Research article generated by system A:
----------------------------------------------------------

{system_a}

----------------------------------------------------------



----------------------------------------------------------
Research article generated by system B:
----------------------------------------------------------

{system_b}

----------------------------------------------------------



----------------------------------------------------------
Research article generated by system C:
----------------------------------------------------------

{system_c}

----------------------------------------------------------



----------------------------------------------------------
Research article generated by system D:
----------------------------------------------------------

{system_d}

----------------------------------------------------------



----------------------------------------------------------
Research article generated by system E:
----------------------------------------------------------

{system_e}

----------------------------------------------------------



Research Question: {question}

Please objectively evaluate the quality of research articles generated by systems A, B, C, D and E for this question, and provide scores out of 10 for the following criteria:
(1) Overall Comprehensiveness: The report should cover content as comprehensively as possible
(2) Thoroughness of Discussion: Each section should be discussed thoroughly, not just superficially
(3) Factuality: There should be minimal factual errors
(4) Coherence: The discussion should stay focused and relevant to the topic

Notes:
- A satisfactory performance deserves around 5 points, with higher scores for excellence and lower scores for deficiencies
- You should not easily assign scores higher than 8 or lower than 3 unless you provide substantial reasoning.
- You do not need to consider citations in the articles


Please analyze each article and provide the final scores in the following JSON format:

```json
{{
  "System A": {{
    "Overall Comprehensiveness": ,
    "Thoroughness of Discussion": ,
    "Factuality": ,
    "Coherence": 
  }},
  "System B": {{
    "Overall Comprehensiveness": ,
    "Thoroughness of Discussion": ,
    "Factuality": ,
    "Coherence": 
  }},
  "System C": {{
    "Overall Comprehensiveness": ,
    "Thoroughness of Discussion": ,
    "Factuality": ,
    "Coherence": 
  }},
  "System D": {{
    "Overall Comprehensiveness": ,
    "Thoroughness of Discussion": ,
    "Factuality": ,
    "Coherence": 
  }},
  "System E": {{
    "Overall Comprehensiveness": ,
    "Thoroughness of Discussion": ,
    "Factuality": ,
    "Coherence": 
  }}
}}
```
"""

# Function to read markdown file content
def read_md_file(filepath):
    """Read a markdown report, stripping citation sections and reasoning.

    Args:
        filepath: Path to the article_<n>.md file.

    Returns:
        The cleaned report text, or None if the file does not exist.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            content = f.read()
    except FileNotFoundError:
        return None
    # Drop trailing citation sections (Gemini uses "Works cited",
    # Grok uses "Key Citations"); .strip() already removes newlines.
    content = content.split("#### **Works cited**")[0].split("#### Key Citations")[0].strip()
    if '</think>' in content:
        # Keep everything AFTER the first reasoning marker. maxsplit=1 is the
        # fix: the previous unbounded split took element [1], which kept only
        # the text BETWEEN the first two markers and silently truncated any
        # report that contained '</think>' more than once.
        content = content.split("</think>", 1)[1].strip()
    return content

# Function to read test questions
def read_test_questions(test_path):
    """Load the test set and return the question strings in file order."""
    with open(test_path, 'r', encoding='utf-8') as handle:
        records = json.load(handle)
    questions = []
    for record in records:
        questions.append(record["Question"])
    return questions

# Function to extract scores from evaluation response
def extract_scores(response_text):
    """Extract the judge's JSON score object from a free-form response.

    Takes the span from the first '{' to the last '}' (the expected schema is
    one top-level object, so this spans any ```json fence the model emits).

    Args:
        response_text: Raw completion text from the judge model.

    Returns:
        The parsed dict of per-system scores, or None if no valid JSON object
        could be found (the failure and raw response are printed).
    """
    start = response_text.find('{')
    end = response_text.rfind('}') + 1
    try:
        if start == -1 or end == 0:
            # No braces at all — treat like a parse failure below.
            raise ValueError("no JSON object found in response")
        return json.loads(response_text[start:end])
    # ValueError covers json.JSONDecodeError; the previous bare `except:` also
    # swallowed KeyboardInterrupt/SystemExit, which made the script hard to stop.
    except ValueError:
        print("Failed to parse JSON from response")
        print(response_text)
        return None

def create_chat_completion(
    instruction: str,
    retry_limit: int = 5,
    model_name: str = "openai/gpt-4o"
) -> str:
    """Send a single-user-message chat completion request, retrying on errors.

    Retries up to `retry_limit` times with a linearly growing backoff; the
    final failure is re-raised to the caller.

    Args:
        instruction: The full prompt sent as one user message.
        retry_limit: Maximum number of attempts before giving up.
        model_name: Model identifier understood by the configured API.

    Returns:
        The content of the first completion choice.
    """
    attempt = 0
    while attempt < retry_limit:
        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": instruction}],
                max_tokens=8192,
            )
            # Surface unexpected empty-choice responses before indexing; the
            # IndexError below is then caught and retried like any other error.
            if not completion.choices:
                print(completion)
            return completion.choices[0].message.content
        except Exception as e:
            print(f"Chat completion error occurred: {e}, Starting retry attempt {attempt + 1}")
            if attempt == retry_limit - 1:
                print(f"Failed after {retry_limit} attempts: {e}")
                raise
            time.sleep(1 * (attempt + 1))
        attempt += 1
    # Reached only when retry_limit <= 0 (the loop body never runs).
    return None

# Dictionaries accumulating results across all judge models:
# judge-model short name -> per-system averaged scores / per-question details.
all_models_scores = {}
all_models_detailed_scores = {}

# Read test questions before starting evaluation
questions = read_test_questions(test_path)

# Run the evaluation once per judge model
for current_model in MODELS_TO_EVALUATE:
    print(f"\nEvaluating with {current_model}...")
    
    # Reset score tracking: one list per (system, metric); each accepted
    # question appends one score to every list.
    system_scores = {
        "naive_rag": {"Comprehensiveness": [], "Thoroughness": [], "Factuality": [], "Coherence": []},
        "rag_r1": {"Comprehensiveness": [], "Thoroughness": [], "Factuality": [], "Coherence": []},
        "webthinker": {"Comprehensiveness": [], "Thoroughness": [], "Factuality": [], "Coherence": []},
        "gemini": {"Comprehensiveness": [], "Thoroughness": [], "Factuality": [], "Coherence": []},
        "grok3": {"Comprehensiveness": [], "Thoroughness": [], "Factuality": [], "Coherence": []}
    }
    
    detailed_scores = []
    
    # Process each article
    # NOTE(review): the article count is hardcoded to 30 — presumably it
    # matches len(questions); questions[i] below would raise IndexError if
    # there are fewer than 30 questions. Consider range(len(questions)).
    for i in tqdm(range(30)):
        article_num = i + 1
        
        # Read articles from each system (read_md_file returns None on a
        # missing file).
        articles = {
            "naive_rag": read_md_file(os.path.join(naive_rag_dir, f"article_{article_num}.md")),
            "rag_r1": read_md_file(os.path.join(rag_r1_dir, f"article_{article_num}.md")),
            "webthinker": read_md_file(os.path.join(model_to_test_dir, f"article_{article_num}.md")),
            "gemini": read_md_file(os.path.join(gemini_dir, f"article_{article_num}.md")),
            "grok3": read_md_file(os.path.join(grok3_dir, f"article_{article_num}.md"))
        }

        # Check if any article is None; a question is only scored when all
        # five systems produced a report, to keep the comparison fair.
        if any(article is None for article in articles.values()):
            print(f"Article {article_num} is None, skipped...")
            continue
        
        # Randomly assign systems to A,B,C,D,E so the judge cannot infer which
        # system produced which article.
        systems = list(articles.keys())
        random.shuffle(systems)
        system_mapping = {f"System {chr(65+i)}": system for i, system in enumerate(systems)}
        
        # Get evaluation instruction
        instruction = get_report_evaluation_instruction(
            question=questions[i],
            system_a=articles[system_mapping["System A"]],
            system_b=articles[system_mapping["System B"]], 
            system_c=articles[system_mapping["System C"]],
            system_d=articles[system_mapping["System D"]],
            system_e=articles[system_mapping["System E"]]
        )

        # Get evaluation from API; this outer retry re-requests when the
        # response could not be parsed into scores (create_chat_completion
        # already retries transport-level failures internally).
        max_retries = 3
        for retry in range(max_retries):
            try:
                response_content = create_chat_completion(instruction, model_name=current_model)
                if response_content:
                    scores = extract_scores(response_content)
                    if not scores and retry < max_retries - 1:
                        print(f"Failed to extract scores, retrying... (attempt {retry + 1})")
                        continue
                        
                    if scores:
                        # Record detailed scores for the current question
                        question_detail = {
                            "question_id": article_num,
                            "question": questions[i],
                            "scores": {}
                        }
                        
                        # Map scores back to original systems (undo the random
                        # A–E shuffle) and append to the per-metric lists.
                        for system_letter, scores_dict in scores.items():
                            original_system = system_mapping[system_letter]
                            system_scores[original_system]["Comprehensiveness"].append(scores_dict["Overall Comprehensiveness"])
                            system_scores[original_system]["Thoroughness"].append(scores_dict["Thoroughness of Discussion"]) 
                            system_scores[original_system]["Factuality"].append(scores_dict["Factuality"])
                            system_scores[original_system]["Coherence"].append(scores_dict["Coherence"])
                            
                            # Add this system's scores for the current question
                            question_detail["scores"][original_system] = {
                                "Overall Comprehensiveness": scores_dict["Overall Comprehensiveness"],
                                "Thoroughness of Discussion": scores_dict["Thoroughness of Discussion"],
                                "Factuality": scores_dict["Factuality"],
                                "Coherence": scores_dict["Coherence"]
                            }
                        
                        detailed_scores.append(question_detail)
                        break
                        
            except Exception as e:
                if retry == max_retries - 1:
                    print(f"Failed to get response for question {i + 1} after {max_retries} attempts: {str(e)}")
                    break
                print(f"Error on attempt {retry + 1}: {str(e)}, retrying...")
                time.sleep(1 * (retry + 1))
                continue

    # Calculate averages for current model
    # NOTE(review): if every question was skipped or unparseable, the score
    # lists are empty and this division raises ZeroDivisionError — verify
    # whether a guard is wanted here.
    final_scores = {}
    for system, scores in system_scores.items():
        final_scores[system] = {
            metric: sum(values)/len(values) 
            for metric, values in scores.items()
        }
    
    # Store results for this model, keyed by the short name (text after the
    # provider prefix, e.g. "deepseek/deepseek-r1" -> "deepseek-r1").
    model_short_name = current_model.split("/")[-1]
    all_models_scores[model_short_name] = final_scores
    all_models_detailed_scores[model_short_name] = detailed_scores
    
    # Save detailed results for current model, timestamped to avoid
    # overwriting earlier runs.
    t = time.localtime()
    timestamp = f"{t.tm_mon}.{t.tm_mday},{t.tm_hour}:{t.tm_min}.{t.tm_sec}"
    detailed_output_path = os.path.join(model_to_test_dir, f"evaluation_scores_detailed.{model_short_name}.{timestamp}.json")
    with open(detailed_output_path, 'w') as f:
        json.dump(detailed_scores, f, indent=4)
    print(f"Detailed results for {model_short_name} saved to:", detailed_output_path)

# Average each system/metric score across all judge models.
average_scores = {
    system: {
        metric: sum(
            all_models_scores[judge.split("/")[-1]][system][metric]
            for judge in MODELS_TO_EVALUATE
        ) / len(MODELS_TO_EVALUATE)
        for metric in ("Comprehensiveness", "Thoroughness", "Factuality", "Coherence")
    }
    for system in system_scores
}

# Combine per-judge results with the cross-judge average under one dict.
combined_results = {
    judge.split("/")[-1]: all_models_scores[judge.split("/")[-1]]
    for judge in MODELS_TO_EVALUATE
}
combined_results["average"] = average_scores

# Persist the combined results next to the model outputs, timestamped so
# repeated runs never overwrite each other.
t = time.localtime()
timestamp = f"{t.tm_mon}.{t.tm_mday},{t.tm_hour}:{t.tm_min}.{t.tm_sec}"
output_path = os.path.join(model_to_test_dir, f"evaluation_scores.combined.{timestamp}.json")
with open(output_path, 'w') as f:
    json.dump(combined_results, f, indent=4)

print("\nEvaluation complete.")
print("Combined results saved to:", output_path)
print("\nFinal combined results:")
print(json.dumps(combined_results, indent=2))


